Add TeLU activation functions telu and telu_fast #622
zengmao wants to merge 4 commits into FluxML:master
Conversation
mcabbott left a comment
Thanks! Some quick comments from a first pass...
>     telu_fast(x)
>
> This is a faster but less accurate version of `telu`. This function is associated with a hard-coded derivative,
> `deriv_telu_fast`, which is faster but less accurate than `deriv_telu`.
Is this meaningfully less accurate? In the tests there are some functions for measuring error, `countepsfrom` and friends.
My guess is that for NN purposes we will only want the fast version, and probably `@fastmath x * tanh_fast(exp(x))` to speed up `exp` too.
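A minimal sketch of that suggestion (assuming NNlib's `tanh_fast` is in scope; the name `telu_fast_sketch` is purely illustrative, not the PR's definition):

```julia
using NNlib: tanh_fast

# Illustrative only: the suggested fast forward pass, x * tanh(exp(x)),
# with @fastmath so that exp is replaced by its fast variant as well.
telu_fast_sketch(x) = @fastmath x * tanh_fast(exp(x))
```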
In my gist, but translated to this notation -- there is hardly any accuracy change:

```julia
julia> worst_eps(telu_fast, telu, -5:0.01f0:5)  # comparing to bigfloat
3

julia> worst_eps(telu, telu, -5:0.01f0:5)
2
```
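For reference, a rough sketch of what such a `worst_eps` helper could look like (hypothetical, not the gist's actual code): evaluate a high-precision reference and report the worst error in units of `eps`.

```julia
# Hypothetical helper: worst error of `fast` over `xs`, in multiples of eps,
# with `exact` evaluated in BigFloat as the reference.
function worst_eps(fast, exact, xs)
    maximum(xs) do x
        target = oftype(fast(x), exact(big(x)))   # correctly rounded reference value
        abs(fast(x) - target) / eps(target)       # error measured in units of eps
    end
end
```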
Thanks for the comments! I've updated the code to reuse … The fast derivative …
Unfortunately, the update broke the test, since I have a type-dependent small-…
P.S. Never mind, the AD trouble is gone after I use …
I timed everything and tried to simplify a bit here: https://gist.github.com/mcabbott/8fb03f175ee4e0c29ef4a7044dc19a85
Since then you've simplified this too, good going. I still wish it were shorter!
Here's what I think the entire derivative code could look like. (I've inlined the "main path" just to have fewer symbols with confusingly similar names around.)

```julia
function deriv_telu(x::Real, _)
    # Adapted from the Discourse post, to avoid bad cancellations:
    # <https://discourse.julialang.org/t/how-to-compute-tanhexp-telu-function-accurately/124464/7>
    exp_x = exp(x)
    tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2
end

function deriv_telu(x::T, Ω = telu(x)) where {T <: Union{Float16, Float32, Float64}}
    # Main path, re-using forward pass:
    tanh_exp_x = Ω / x
    sech_exp_x_squared = 1 - tanh_exp_x^2
    main = @fastmath tanh_exp_x + x * exp(x) * sech_exp_x_squared
    # That's badly behaved at zero, switch to a Taylor series:
    taylor = _deriv_telu_taylor(x)
    # It's also badly behaved at large x, switch to 1!
    ifelse(abs(x) < T(0.01), taylor,   # this works just as well
        ifelse(x > 4, one(x), main))   # as does this
end

# Taylor coefficients are (tanh(1), 8*exp(1)^2 / (1+exp(1)^2)^2)
_deriv_telu_taylor(x::T) where T = convert(T, evalpoly(x, (0.7615941559557649, 0.8399486832280524)))
_deriv_telu_taylor(x::Float32) = evalpoly(x, (0.7615942f0, 0.83994865f0))
_deriv_telu_taylor(x::Float16) = evalpoly(x, (Float16(0.7617), Float16(0.84)))
```

In fact, the whole first … (The more exact formula could be kept in the tests, to compare there to this piecewise thing.)
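If something along these lines goes in, a quick illustrative check (not part of the PR) could exercise the main path, the Taylor branch, and the large-x cutoff against a central finite difference, assuming `telu` and the sketch above are loaded:

```julia
# Illustrative spot checks of the piecewise derivative against a finite difference of telu:
fd(f, x; h = cbrt(eps(one(x)))) = (f(x + h) - f(x - h)) / (2h)

for x in (-3.0f0, -0.005f0, 0.005f0, 1.0f0, 10.0f0)   # main path, Taylor branch, large-x cutoff
    @assert isapprox(deriv_telu(x), fd(telu, x); atol = 1e-3)
end
```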
On one hand, the …
P.S. Maybe adding a second-order term to the Taylor expansion will guarantee sufficient accuracy for any practical NN purpose.
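If that route is taken, one way to obtain the extra coefficient without hand algebra is a high-precision central difference of the exact derivative at zero; a hedged sketch, written out here for illustration only:

```julia
# Illustrative only: estimate telu''(0)/2, i.e. the x^2 coefficient in the
# Taylor expansion of the derivative, via a BigFloat central difference.
setprecision(BigFloat, 256) do
    dtelu(x) = tanh(exp(x)) + x * exp(x) * sech(exp(x))^2   # exact telu'(x)
    h = big"1e-20"
    Float64((dtelu(h) - 2 * dtelu(big(0)) + dtelu(-h)) / (2 * h^2))
end
```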
Ah, sorry, I see some larger errors which I missed. I wonder a little bit if we should just take the hit & simplify to this. Counting forward + reverse it's only 30% slower (and may allow sufficiently smart AD not to keep the array containing …):

```julia
function deriv_telu_2(x::Real)
    # This version does not re-use the forward pass, as doing so has 0/0 problems at x = 0:
    exp_x = @fastmath exp(x)
    tanh_exp_x = tanh_fast(exp_x)
    main = tanh_exp_x + x * exp_x * (1 - tanh_exp_x^2)
    # That gives NaN at large x, where telu(x) is just relu(x) anyway:
    ifelse(x > 4, one(float(x)), main)
end
```
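A couple of illustrative spot checks on exactly the points those comments call out, assuming NNlib's `tanh_fast` and the sketch above are in scope:

```julia
# x = 0 should hit the well-behaved formula directly (tanh(1) ≈ 0.7616), not 0/0,
# and large x should return exactly one(x) even though the naive formula gives NaN:
@show deriv_telu_2(0.0f0)    # ≈ 0.7616f0
@show deriv_telu_2(100.0f0)  # 1.0f0
```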
The new code which recomputes …
I can also update the code and the accuracy tests (…
This PR adds the TeLU activation function advocated by a recent paper, following discussions on Julia Discourse.
`telu` and `telu_fast` have been added to `activation.jl`. The latter is slightly faster and uses `tanh_fast` while sacrificing a bit of accuracy. The hard-coded derivatives are `deriv_telu` and `deriv_telu_fast`, respectively. The accuracy gap between the derivative functions is more significant, as `deriv_telu` re-organizes the expression to avoid numerical instabilities.
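For context, a minimal sketch of the forward definition being described (not necessarily the exact PR code):

```julia
# Sketch of the TeLU formula from the paper / Discourse thread; `telu_fast` is
# described as the same expression with NNlib's `tanh_fast` in place of `tanh`.
telu(x) = x * tanh(exp(x))
```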